package users

import (
	"context"
	"database/sql"
	"errors"
	"github.com/go-kit/kit/log"
	_ "github.com/go-sql-driver/mysql"
)

// UserRepository is the persistence contract for users and their roles.
type UserRepository interface {
	// GetByUsername loads the user with the given username, including roles.
	// NOTE(review): the default implementation returns (nil, nil) when no
	// user matches; other implementations should follow suit — confirm.
	GetByUsername(ctx context.Context, username string) (*User, error)
	// GetRoles lists the role names assigned to the user with the given id.
	GetRoles(ctx context.Context, id string) ([]string, error)
	// Save persists a new user together with its roles.
	Save(ctx context.Context, user *User) error
}

// UserRow mirrors one row of the user table as read and written by the
// repository. Level, Name and Avatar may be NULL in the database, so they
// are scanned into sql.NullString.
type UserRow struct {
	Id, Username, Password string
	Level, Name, Avatar    sql.NullString
}

func NewUserRow(user *User) *UserRow {
	return &UserRow{
		Id:       *user.Id,
		Username: *user.Username,
		Password: MD5(*user.Password),
		Level:    sql.NullString{String: *user.Level, Valid: true},
	}
}

// RoleRow is one (user id, role name) pair from the role table.
type RoleRow struct {
	Id, Role string
}

// UserRowToUser maps a scanned UserRow back into a domain User.
// Columns that were NULL (Level, Name, Avatar) come through as empty
// strings, because a NULL scan leaves NullString.String empty.
// Roles are not populated here; callers attach them separately.
func UserRowToUser(userRow UserRow) *User {
	return new(User).
		SetId(userRow.Id).
		SetUsername(userRow.Username).
		SetPassword(userRow.Password).
		SetLevel(userRow.Level.String).
		SetName(userRow.Name.String).
		SetAvatar(userRow.Avatar.String)
}

// NewUserRoleRow pairs a user id with one role name, producing a row to be
// inserted into the role table.
func NewUserRoleRow(uid string, role string) *RoleRow {
	row := new(RoleRow)
	row.Id, row.Role = uid, role
	return row
}

// defaultUserRepository is the *sql.DB-backed implementation of the
// UserRepository methods defined in this file. It holds no state beyond
// its two dependencies.
type defaultUserRepository struct {
	// Conn is the database handle used for every query and transaction.
	Conn   *sql.DB
	// logger receives diagnostic messages when queries, scans or
	// transactions fail.
	logger log.Logger
}

// NewUserRepository wires a database handle and a logger into a ready-to-use
// repository. Neither argument is validated; both are stored as given.
func NewUserRepository(conn *sql.DB, logger log.Logger) *defaultUserRepository {
	repo := new(defaultUserRepository)
	repo.Conn = conn
	repo.logger = logger
	return repo
}

// GetByUsername loads the user with the given username and attaches its roles.
//
// It returns (nil, nil) when no row matches, so callers must check for a nil
// user rather than an error. Any other scan error is logged and returned
// unchanged; errors from GetRoles are returned as-is. ctx bounds both the
// user query and the role query.
func (u *defaultUserRepository) GetByUsername(ctx context.Context, username string) (*User, error) {
	row := u.Conn.QueryRowContext(ctx, "SELECT id,username,password,levels,name,avatar FROM user WHERE username=? ", username)
	var userRow UserRow
	if err := row.Scan(&userRow.Id, &userRow.Username, &userRow.Password, &userRow.Level, &userRow.Name, &userRow.Avatar); err != nil {
		if errors.Is(err, sql.ErrNoRows) {
			// No such user: signalled by a nil user, not by an error.
			return nil, nil
		}
		_ = u.logger.Log("get user scan error:", err)
		return nil, err
	}
	user := UserRowToUser(userRow)
	roles, err := u.GetRoles(ctx, userRow.Id)
	if err != nil {
		return nil, err
	}
	user.Roles = roles
	return user, nil
}

// Save inserts a new user and all of its roles in a single transaction.
//
// It first checks whether the username already exists and refuses to save
// a duplicate. The user row and every role row are then written inside one
// transaction. Any failure rolls the transaction back, and a commit failure
// is reported to the caller instead of being swallowed. ctx bounds the
// existence check and the whole transaction.
//
// NOTE(review): the existence check and the INSERT are not atomic. Two
// concurrent saves of the same username are only prevented by a UNIQUE
// constraint on user.username, which is not visible here — confirm.
// TODO: duplicate roles in user.Roles are inserted as-is.
func (u *defaultUserRepository) Save(ctx context.Context, user *User) error {
	userRow := NewUserRow(user)

	// Build the role rows for this user.
	roleRows := make([]*RoleRow, 0, len(user.Roles))
	for _, role := range user.Roles {
		roleRows = append(roleRows, NewUserRoleRow(*user.Id, role))
	}

	// Refuse to save when the username is already taken.
	alreadyUser, err := u.GetByUsername(ctx, *user.Username)
	if err != nil {
		_ = u.logger.Log("get user error", err)
		return errors.New("查询用户是否存在error")
	}
	if alreadyUser != nil {
		return errors.New("用户名已存在")
	}

	tx, err := u.Conn.BeginTx(ctx, nil)
	if err != nil {
		return err
	}
	// rollback aborts the transaction after a failed step. A failure to roll
	// back is logged with its own error, not the error that triggered it.
	rollback := func() {
		if rollbackErr := tx.Rollback(); rollbackErr != nil {
			_ = u.logger.Log("save user rollback error:", rollbackErr)
		}
	}

	// Insert the user row.
	if _, err = tx.ExecContext(ctx, "INSERT INTO user(id, username, password, levels) VALUES (?,?,?,?);",
		userRow.Id, userRow.Username, userRow.Password, userRow.Level); err != nil {
		rollback()
		_ = u.logger.Log("insert user error:", err)
		return errors.New("insert user error")
	}

	// Insert the role rows through one prepared statement.
	stmt, err := tx.PrepareContext(ctx, "INSERT INTO role VALUES (?,?);")
	if err != nil {
		// Without this rollback the transaction, and its pooled connection,
		// would stay open.
		rollback()
		_ = u.logger.Log("prepare role insert error:", err)
		return err
	}
	defer func() {
		if closeErr := stmt.Close(); closeErr != nil {
			_ = u.logger.Log("insert role close error:", closeErr)
		}
	}()
	for _, role := range roleRows {
		if _, err = stmt.ExecContext(ctx, role.Id, role.Role); err != nil {
			rollback()
			_ = u.logger.Log("insert user role error:", err)
			return errors.New("insert user role error")
		}
	}

	// Commit, and surface a commit failure to the caller: the data was not saved.
	if err = tx.Commit(); err != nil {
		_ = u.logger.Log("save user commit error:", err)
		return err
	}
	return nil
}

// GetRoles returns the role names stored for the user with the given id.
//
// The result is never nil on success; a user with no roles yields an empty
// slice. A query failure is logged and replaced with a user-facing error.
// Scan and iteration errors are logged and returned unchanged. ctx bounds
// the query.
func (u *defaultUserRepository) GetRoles(ctx context.Context, id string) ([]string, error) {
	rows, err := u.Conn.QueryContext(ctx, "SELECT id, role FROM role WHERE id=?", id)
	if err != nil {
		_ = u.logger.Log("get user role error:", err)
		return nil, errors.New("用户权限暂时无法获取,请重试")
	}
	defer func() {
		if closeErr := rows.Close(); closeErr != nil {
			_ = u.logger.Log("rows close error:", closeErr)
		}
	}()

	// Non-nil on purpose, so an empty role list stays distinguishable from nil.
	roles := make([]string, 0)
	for rows.Next() {
		var row RoleRow
		if err := rows.Scan(&row.Id, &row.Role); err != nil {
			_ = u.logger.Log("get user role scan error:", err)
			return nil, err
		}
		roles = append(roles, row.Role)
	}
	// rows.Next returns false on both exhaustion and failure; only rows.Err
	// tells them apart. Without this check a truncated result looks like success.
	if err := rows.Err(); err != nil {
		_ = u.logger.Log("get user role iterate error:", err)
		return nil, err
	}
	return roles, nil
}
